{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.3 word2vec的实现"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import math\n",
    "import random\n",
    "import sys\n",
    "import time\n",
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.utils.data as Data\n",
    "import d2lzh as d2l"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# sentences: 42068'"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open('data/ptb/ptb.train.txt', 'r') as f:\n",
    "    lines = f.readlines()\n",
    "    # st是sentence的缩写\n",
    "    raw_dataset = [st.split() for st in lines]\n",
    "'# sentences: %d' % len(raw_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# tokens: 24 ['aer', 'banknote', 'berlitz', 'calloway', 'centrust']\n",
      "# tokens: 15 ['pierre', '<unk>', 'N', 'years', 'old']\n",
      "# tokens: 11 ['mr.', '<unk>', 'is', 'chairman', 'of']\n"
     ]
    }
   ],
   "source": [
    "for st in raw_dataset[:3]:\n",
    "    print('# tokens:', len(st), st[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count every token in the corpus, then keep only tokens seen at least 5 times.\n",
    "counter = collections.Counter(token for sentence in raw_dataset for token in sentence)\n",
    "counter = {token: freq for token, freq in counter.items() if freq >= 5}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# tokens: 887100'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index the retained vocabulary and map each sentence to token ids,\n",
    "# dropping tokens that were filtered out as too rare.\n",
    "idx_to_token = list(counter)\n",
    "token_to_idx = {token: i for i, token in enumerate(idx_to_token)}\n",
    "dataset = [[token_to_idx[token] for token in sentence if token in token_to_idx]\n",
    "           for sentence in raw_dataset]\n",
    "num_tokens = sum(map(len, dataset))\n",
    "'# tokens: %d' % num_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# tokens: 375413'"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def discard(idx):\n",
    "    \"\"\"Randomly decide whether token id `idx` is dropped during subsampling.\n",
    "\n",
    "    The drop probability is 1 - sqrt(1e-4 / f), where f = count / num_tokens is the\n",
    "    token's relative frequency; it is compared against a single uniform draw.\n",
    "    Tokens with f <= 1e-4 get a non-positive probability and are always kept.\n",
    "    \"\"\"\n",
    "    keep_score = math.sqrt(1e-4 / counter[idx_to_token[idx]] * num_tokens)\n",
    "    return random.uniform(0, 1) < 1 - keep_score\n",
    "subsampled_dataset = [[tk for tk in sentence if not discard(tk)] for sentence in dataset]\n",
    "'# tokens: %d' % sum(len(sentence) for sentence in subsampled_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# the: before=50770, after=2131'"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def compare_counts(token):\n",
    "    \"\"\"Report how often `token` occurs in `dataset` vs. `subsampled_dataset`.\"\"\"\n",
    "    before = sum(sentence.count(token_to_idx[token]) for sentence in dataset)\n",
    "    after = sum(sentence.count(token_to_idx[token]) for sentence in subsampled_dataset)\n",
    "    return '# %s: before=%d, after=%d' % (token, before, after)\n",
    "compare_counts('the')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'# join: before=45, after=45'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compare_counts('join')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_centers_and_contexts(dataset, max_window_size):\n",
    "    \"\"\"Collect (center word, context words) pairs for skip-gram training.\n",
    "\n",
    "    Every token of every sentence with at least 2 tokens becomes a center word. For each\n",
    "    center a window size w is drawn uniformly from [1, max_window_size]; its contexts are\n",
    "    the tokens at most w positions away on either side, excluding the center itself.\n",
    "    Returns two parallel lists: centers and contexts (one token list per center).\n",
    "    \"\"\"\n",
    "    centers, contexts = [], []\n",
    "    for sentence in dataset:\n",
    "        # A sentence needs at least 2 tokens to form a single center/context pair.\n",
    "        if len(sentence) < 2:\n",
    "            continue\n",
    "        centers.extend(sentence)\n",
    "        for pos in range(len(sentence)):\n",
    "            window = random.randint(1, max_window_size)\n",
    "            lo, hi = max(0, pos - window), min(len(sentence), pos + 1 + window)\n",
    "            contexts.append([sentence[j] for j in range(lo, hi) if j != pos])\n",
    "    return centers, contexts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dataset [[0, 1, 2, 3, 4, 5, 6], [7, 8, 9]]\n",
      "center 0 has contexts [1]\n",
      "center 1 has contexts [0, 2]\n",
      "center 2 has contexts [0, 1, 3, 4]\n",
      "center 3 has contexts [2, 4]\n",
      "center 4 has contexts [3, 5]\n",
      "center 5 has contexts [3, 4, 6]\n",
      "center 6 has contexts [4, 5]\n",
      "center 7 has contexts [8, 9]\n",
      "center 8 has contexts [7, 9]\n",
      "center 9 has contexts [7, 8]\n"
     ]
    }
   ],
   "source": [
    "tiny_dataset = [list(range(7)), list(range(7, 10))]\n",
    "print('dataset', tiny_dataset)\n",
    "for center, context in zip(*get_centers_and_contexts(tiny_dataset, 2)):\n",
    "    print('center', center, 'has contexts', context)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_centers, all_contexts = get_centers_and_contexts(subsampled_dataset, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_negatives(all_contexts, sampling_weights, K):\n",
    "    \"\"\"Draw K noise (negative) word ids per context word for negative sampling.\n",
    "\n",
    "    all_contexts: list of context-word lists, one per center word.\n",
    "    sampling_weights: sampling weight of each token id (here count ** 0.75).\n",
    "    K: number of noise words per context word.\n",
    "    Returns a list parallel to all_contexts whose i-th entry holds\n",
    "    len(all_contexts[i]) * K ids, none of which appears in all_contexts[i].\n",
    "    \"\"\"\n",
    "    from itertools import accumulate\n",
    "    all_negatives, neg_candidates, i = [], [], 0\n",
    "    # Every token id in the vocabulary is a candidate.\n",
    "    population = list(range(len(sampling_weights)))\n",
    "    # random.choices(weights=...) converts to cumulative weights on every call;\n",
    "    # doing it once yields the same draws without the repeated O(vocab) work.\n",
    "    cum_weights = list(accumulate(sampling_weights))\n",
    "    for contexts in all_contexts:\n",
    "        negatives = []\n",
    "        # Build the membership set once per center instead of once per candidate.\n",
    "        context_set = set(contexts)\n",
    "        while len(negatives) < len(contexts) * K:\n",
    "            if i == len(neg_candidates):\n",
    "                # Sample candidates in large batches (1e5) to amortize the sampling cost;\n",
    "                # unused candidates carry over to the following centers.\n",
    "                i, neg_candidates = 0, random.choices(\n",
    "                    population, cum_weights=cum_weights, k=int(1e5)\n",
    "                )\n",
    "            neg, i = neg_candidates[i], i + 1\n",
    "            # A noise word must not be one of this center's context words.\n",
    "            if neg not in context_set:\n",
    "                negatives.append(neg)\n",
    "        all_negatives.append(negatives)\n",
    "    return all_negatives"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "sampling_weights = [counter[w]**0.75 for w in idx_to_token]\n",
    "all_negatives = get_negatives(all_contexts, sampling_weights, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset yielding (center, contexts, negatives) triples.\"\"\"\n",
    "\n",
    "    def __init__(self, centers, contexts, negatives):\n",
    "        # The three lists are parallel: entry i of each describes the same center word.\n",
    "        assert len(centers) == len(contexts) == len(negatives)\n",
    "        self.centers, self.contexts, self.negatives = centers, contexts, negatives\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        return self.centers[index], self.contexts[index], self.negatives[index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.centers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batchify(data):\n",
    "    \"\"\"Collate (center, contexts, negatives) samples into padded tensors.\n",
    "\n",
    "    Returns (centers, contexts_negatives, masks, labels):\n",
    "      centers            -- shape (batch, 1)\n",
    "      contexts_negatives -- contexts then negatives, zero-padded to the longest sample\n",
    "      masks              -- 1 for real entries, 0 for padding\n",
    "      labels             -- 1 for context positions, 0 for negatives and padding\n",
    "    \"\"\"\n",
    "    max_len = max(len(ctx) + len(neg) for _, ctx, neg in data)\n",
    "    centers, contexts_negatives, masks, labels = [], [], [], []\n",
    "    for center, ctx, neg in data:\n",
    "        used = len(ctx) + len(neg)\n",
    "        pad = max_len - used\n",
    "        centers.append(center)\n",
    "        contexts_negatives.append(ctx + neg + [0] * pad)\n",
    "        masks.append([1] * used + [0] * pad)\n",
    "        labels.append([1] * len(ctx) + [0] * (max_len - len(ctx)))\n",
    "    return (torch.tensor(centers).view(-1, 1),\n",
    "            torch.tensor(contexts_negatives),\n",
    "            torch.tensor(masks),\n",
    "            torch.tensor(labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "centers shape: torch.Size([512, 1])\n",
      "contexts_negatives shape: torch.Size([512, 60])\n",
      "masks shape: torch.Size([512, 60])\n",
      "labels shape: torch.Size([512, 60])\n"
     ]
    }
   ],
   "source": [
    "batch_size = 512\n",
    "# d2l convention: load in the main process on Windows, use 4 workers elsewhere.\n",
    "num_workers = 0 if sys.platform.startswith('win32') else 4\n",
    "# Use a distinct name so the token-id corpus `dataset` (read by compare_counts)\n",
    "# is not shadowed when cells are re-run.\n",
    "train_dataset = MyDataset(all_centers, all_contexts, all_negatives)\n",
    "data_iter = Data.DataLoader(train_dataset, batch_size, shuffle=True,\n",
    "                            collate_fn=batchify, num_workers=num_workers)\n",
    "for batch in data_iter:\n",
    "    for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):\n",
    "        print(name, 'shape:', data.shape)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[-7.5014e-01, -6.7758e-01,  1.3905e-01, -3.2504e-01],\n",
       "        [ 4.6432e-01, -1.2951e+00, -6.3469e-01,  1.7867e+00],\n",
       "        [-1.4255e+00, -5.1740e-01, -1.8344e+00, -8.6178e-01],\n",
       "        [-4.1812e-01,  9.8486e-01,  8.6274e-01, -1.8278e-01],\n",
       "        [-4.8235e-02, -4.4251e-01,  1.8103e+00,  1.4304e-03],\n",
       "        [ 6.0103e-01,  5.2687e-01, -7.9238e-01, -3.8206e-01],\n",
       "        [-6.4648e-01,  6.1382e-01, -2.6217e-01,  1.8242e+00],\n",
       "        [ 2.1424e-01, -1.2573e+00,  9.9863e-01,  4.7190e-01],\n",
       "        [-1.4787e-01, -1.3340e+00,  9.4021e-03,  2.0213e-01],\n",
       "        [-2.3015e-01, -5.4096e-01,  7.9691e-01,  5.0277e-01],\n",
       "        [-2.4606e-01, -1.6335e+00, -9.0530e-03, -5.7450e-01],\n",
       "        [ 1.7134e+00, -1.9645e+00, -1.8626e-01,  8.8031e-01],\n",
       "        [ 1.7762e+00, -1.2961e-01, -1.2376e+00, -5.2555e-01],\n",
       "        [ 1.7810e-01,  1.5218e+00, -2.2159e-01,  2.5675e-01],\n",
       "        [ 1.5030e+00,  2.1536e-01,  1.1065e+00, -3.9987e-01],\n",
       "        [ 3.3883e-01, -5.9900e-01, -4.5714e-01,  1.7813e+00],\n",
       "        [-2.2987e-01, -1.6290e+00,  1.2975e+00, -5.6441e-01],\n",
       "        [ 6.4336e-01,  4.5049e-02, -1.5265e+00,  5.7984e-01],\n",
       "        [-3.3092e-01,  1.1003e-01,  9.1502e-01, -2.5245e-01],\n",
       "        [ 9.8932e-01,  4.5280e-01, -1.1923e+00, -1.7245e+00]],\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embed = nn.Embedding(num_embeddings=20, embedding_dim=4)\n",
    "embed.weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[ 4.6432e-01, -1.2951e+00, -6.3469e-01,  1.7867e+00],\n",
       "         [-1.4255e+00, -5.1740e-01, -1.8344e+00, -8.6178e-01],\n",
       "         [-4.1812e-01,  9.8486e-01,  8.6274e-01, -1.8278e-01]],\n",
       "\n",
       "        [[-4.8235e-02, -4.4251e-01,  1.8103e+00,  1.4304e-03],\n",
       "         [ 6.0103e-01,  5.2687e-01, -7.9238e-01, -3.8206e-01],\n",
       "         [-6.4648e-01,  6.1382e-01, -2.6217e-01,  1.8242e+00]]],\n",
       "       grad_fn=<EmbeddingBackward>)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.long)\n",
    "embed(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 1, 6])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones((2, 1, 4))\n",
    "Y = torch.ones((2, 4, 6))\n",
    "torch.bmm(X, Y).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def skip_gram(center, contexts_and_negatives, embed_v, embed_u):\n",
    "    \"\"\"Score every context/noise word against its center word with a dot product.\n",
    "\n",
    "    Each word has two embeddings: embed_v (as a center word) and embed_u (as a\n",
    "    context or noise word). For center of shape (batch, 1) and contexts_and_negatives\n",
    "    of shape (batch, max_len), as produced by batchify, the result is (batch, 1, max_len).\n",
    "    \"\"\"\n",
    "    center_vecs = embed_v(center)\n",
    "    ctx_vecs = embed_u(contexts_and_negatives)\n",
    "    # Move the embedding dim to the middle so bmm contracts over it.\n",
    "    return torch.bmm(center_vecs, ctx_vecs.permute(0, 2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SigmoidBinaryCrossEntropyLoss(nn.Module):\n",
    "    \"\"\"Element-wise sigmoid binary cross-entropy with an optional mask.\"\"\"\n",
    "    def __init__(self):\n",
    "        super(SigmoidBinaryCrossEntropyLoss, self).__init__()\n",
    "    def forward(self, inputs, targets, mask=None):\n",
    "        \"\"\"\n",
    "        inputs  - logits, Tensor of shape (batch_size, len)\n",
    "        targets - 0/1 labels, Tensor of the same shape as inputs\n",
    "        mask    - optional 0/1 weights of the same shape; masked-out entries add 0 loss\n",
    "        Returns the loss averaged over dim 1 (masked entries still count in the\n",
    "        denominator), i.e. shape (batch_size,).\n",
    "        \"\"\"\n",
    "        inputs, targets = inputs.float(), targets.float()\n",
    "        # mask defaults to None; only convert it when given (None has no .float()).\n",
    "        if mask is not None:\n",
    "            mask = mask.float()\n",
    "        res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none', weight=mask)\n",
    "        return res.mean(dim=1)\n",
    "loss = SigmoidBinaryCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8740, 1.2100])"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred = torch.tensor([[1.5, 0.3, -1, 2], [1.1, -0.6, 2.2, 0.4]])\n",
    "# 标签变量label中的1和0分别代表背景词和噪声词\n",
    "label = torch.tensor([[1, 0, 0, 0], [1, 1, 0, 0]])\n",
    "# 掩码变量\n",
    "mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])\n",
    "loss(pred, label, mask)*mask.shape[1]/mask.float().sum(dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.8740\n",
      "1.2100\n"
     ]
    }
   ],
   "source": [
    "def sigmd(x):\n",
    "    \"\"\"Return -log(sigmoid(x)) = log(1 + exp(-x)) without overflowing.\n",
    "\n",
    "    The naive -log(1 / (1 + exp(-x))) overflows math.exp for large negative x;\n",
    "    splitting off max(-x, 0) keeps the exponent non-positive.\n",
    "    \"\"\"\n",
    "    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))\n",
    "# Since 1 - sigmoid(x) = sigmoid(-x), a label-0 (noise) entry contributes\n",
    "# -log(1 - sigmoid(x)) = sigmd(-x), while a label-1 (context) entry contributes sigmd(x).\n",
    "print('%.4f' % ((sigmd(1.5)+sigmd(-0.3)+sigmd(1)+sigmd(-2))/4))\n",
    "print('%.4f' % ((sigmd(1.1)+sigmd(-0.6)+sigmd(-2.2))/3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_size = 100\n",
    "net = nn.Sequential(\n",
    "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size), \n",
    "    nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(net, lr, num_epochs, device=None):\n",
    "    \"\"\"Train the skip-gram embeddings in `net` with Adam, printing the loss per epoch.\n",
    "\n",
    "    net[0] is the center-word embedding and net[1] the context-word embedding.\n",
    "    Batches come from the global `data_iter`; the criterion is the global `loss`.\n",
    "    device: where to train; defaults to 'cuda' when available, otherwise 'cpu'.\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    net = net.to(device)\n",
    "    optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "    for epoch in range(num_epochs):\n",
    "        start, l_sum, n = time.time(), 0.0, 0\n",
    "        for batch in data_iter:\n",
    "            center, context_negative, mask, label = [d.to(device) for d in batch]\n",
    "            pred = skip_gram(center, context_negative, net[0], net[1])\n",
    "            # loss() averages over all max_len positions; rescale by max_len / #unmasked\n",
    "            # so padding does not dilute each sample's loss, then average over the batch.\n",
    "            l = (loss(pred.view(label.shape), label, mask) * mask.shape[1] / mask.float().sum(dim=1)).mean()\n",
    "            optimizer.zero_grad()\n",
    "            l.backward()\n",
    "            optimizer.step()\n",
    "            l_sum += l.item()\n",
    "            n += 1\n",
    "        # max(n, 1) avoids ZeroDivisionError if data_iter yields no batches.\n",
    "        print('epoch %d, loss %.2f, time %.2f' % (epoch + 1, l_sum / max(n, 1), time.time() - start))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 1.96, time 14.15\n",
      "epoch 2, loss 0.62, time 13.38\n",
      "epoch 3, loss 0.45, time 13.23\n",
      "epoch 4, loss 0.39, time 13.57\n",
      "epoch 5, loss 0.37, time 13.39\n",
      "epoch 6, loss 0.35, time 13.27\n",
      "epoch 7, loss 0.34, time 13.31\n",
      "epoch 8, loss 0.33, time 13.50\n",
      "epoch 9, loss 0.32, time 13.92\n",
      "epoch 10, loss 0.32, time 13.38\n"
     ]
    }
   ],
   "source": [
    "train(net, 0.01, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.486: bugs\n",
      "cosine sim=0.484: computers\n",
      "cosine sim=0.444: mips\n"
     ]
    }
   ],
   "source": [
    "def get_similar_tokens(query_tokens, k, embed):\n",
    "    \"\"\"Print the k tokens whose embeddings are most cosine-similar to `query_tokens`.\n",
    "\n",
    "    query_tokens: a single token string, looked up in token_to_idx.\n",
    "    embed: an nn.Embedding whose weight rows are indexed by token id.\n",
    "    \"\"\"\n",
    "    W = embed.weight.data\n",
    "    x = W[token_to_idx[query_tokens]]\n",
    "    # Add 1e-9 to the product of squared norms (not only to |x|^2) so an all-zero\n",
    "    # row of W yields cos = 0 instead of 0/0 = nan.\n",
    "    cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()\n",
    "    _, topk = torch.topk(cos, k=k + 1)\n",
    "    topk = topk.cpu().numpy()\n",
    "    # Skip the top hit, expected to be the query token itself.\n",
    "    for i in topk[1:]:\n",
    "        print('cosine sim=%.3f: %s' % (cos[i], (idx_to_token[i])))\n",
    "get_similar_tokens('chip', 3, net[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.6 求近义词和类比词"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip installs into the environment of the running kernel; `!pip` uses whichever\n",
    "# pip is first on PATH, which may belong to a different interpreter.\n",
    "%pip install torchtext==0.4.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['charngram.100d', 'fasttext.en.300d', 'fasttext.simple.300d', 'glove.42B.300d', 'glove.840B.300d', 'glove.twitter.27B.25d', 'glove.twitter.27B.50d', 'glove.twitter.27B.100d', 'glove.twitter.27B.200d', 'glove.6B.50d', 'glove.6B.100d', 'glove.6B.200d', 'glove.6B.300d'])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torchtext.vocab as vocab\n",
    "vocab.pretrained_aliases.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['glove.42B.300d',\n",
       " 'glove.840B.300d',\n",
       " 'glove.twitter.27B.25d',\n",
       " 'glove.twitter.27B.50d',\n",
       " 'glove.twitter.27B.100d',\n",
       " 'glove.twitter.27B.200d',\n",
       " 'glove.6B.50d',\n",
       " 'glove.6B.100d',\n",
       " 'glove.6B.200d',\n",
       " 'glove.6B.300d']"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "[key for key in vocab.pretrained_aliases.keys() if 'glove' in key]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "data/glove/glove.6B.zip: 862MB [08:31, 1.69MB/s]                               \n",
      "100%|█████████▉| 399999/400000 [00:26<00:00, 14880.24it/s]\n"
     ]
    }
   ],
   "source": [
    "cache_dir = 'data/glove'\n",
    "glove = vocab.GloVe(name='6B', dim=50, cache=cache_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "一共包含400000个词。\n"
     ]
    }
   ],
   "source": [
    "print('一共包含%d个词。' % len(glove.stoi))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3366, 'beautiful')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "glove.stoi['beautiful'], glove.itos[3366]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def knn(W, x, k):\n",
    "    \"\"\"Return (indices, cosine similarities) of the k rows of W closest to x.\n",
    "\n",
    "    The 1e-9 added to each row's squared norm keeps all-zero rows from dividing by zero.\n",
    "    \"\"\"\n",
    "    query = x.view((-1,))\n",
    "    row_norms = (torch.sum(W * W, dim=1) + 1e-9).sqrt()\n",
    "    cos = torch.matmul(W, query) / (row_norms * torch.sum(x * x).sqrt())\n",
    "    topk = torch.topk(cos, k=k)[1].cpu().numpy()\n",
    "    return topk, [cos[idx].item() for idx in topk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_similar_tokens(query_token, k, embed):\n",
    "    \"\"\"Print the k nearest neighbours (cosine similarity) of `query_token` in a\n",
    "    pretrained vocab `embed` exposing .vectors, .stoi and .itos.\"\"\"\n",
    "    query_vec = embed.vectors[embed.stoi[query_token]]\n",
    "    topk, cos = knn(embed.vectors, query_vec, k + 1)\n",
    "    # The closest hit is the query token itself, so it is skipped.\n",
    "    for idx, sim in list(zip(topk, cos))[1:]:\n",
    "        print('cosine sim=%.3f: %s' % (sim, (embed.itos[idx])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.856: chips\n",
      "cosine sim=0.749: intel\n",
      "cosine sim=0.749: electronics\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('chip', 3, glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.839: babies\n",
      "cosine sim=0.800: boy\n",
      "cosine sim=0.792: girl\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('baby', 3, glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cosine sim=0.921: lovely\n",
      "cosine sim=0.893: gorgeous\n",
      "cosine sim=0.830: wonderful\n"
     ]
    }
   ],
   "source": [
    "get_similar_tokens('beautiful', 3, glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_analogy(token_a, token_b, token_c, embed):\n",
    "    \"\"\"Solve 'a : b :: c : ?' by returning the token whose vector is most\n",
    "    cosine-similar to vec(b) - vec(a) + vec(c). The inputs are not excluded.\"\"\"\n",
    "    vec_a, vec_b, vec_c = (embed.vectors[embed.stoi[t]] for t in (token_a, token_b, token_c))\n",
    "    target = vec_b - vec_a + vec_c\n",
    "    best, _ = knn(embed.vectors, target, 1)\n",
    "    return embed.itos[best[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'daughter'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('man', 'woman', 'son', glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'japan'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('beijing', 'china', 'tokyo', glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'biggest'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('bad', 'worst', 'big', glove)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'went'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_analogy('do', 'did', 'go', glove)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.7 文本情感分类：使用循环神经网络"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import os\n",
    "import random\n",
    "import tarfile\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.utils.data as Data\n",
    "import sys\n",
    "import d2lzh as d2l\n",
    "# NOTE(review): this runs after torch is imported; on a top-to-bottom run, train()\n",
    "# above has already moved a model to 'cuda', so this setting likely has no effect.\n",
    "# Set it before any CUDA use (e.g. in the first cell) -- confirm.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "# Fall back to CPU when no GPU is available.\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# Root directory for the IMDb archive and the cached GloVe vectors.\n",
    "DATA_ROOT = 'data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从压缩包解压...\n"
     ]
    }
   ],
   "source": [
    "# Extract the IMDb review archive into DATA_ROOT unless DATA_ROOT/aclImdb already exists.\n",
    "fname = os.path.join(DATA_ROOT, 'aclImdb_v1.tar.gz')\n",
    "if not os.path.exists(os.path.join(DATA_ROOT, 'aclImdb')):\n",
    "    print('从压缩包解压...')\n",
    "    # NOTE(review): extractall() writes members using the paths stored in the archive;\n",
    "    # only run this on a trusted download of aclImdb_v1.tar.gz.\n",
    "    with tarfile.open(fname, 'r') as f:\n",
    "        f.extractall(DATA_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12500/12500 [00:00<00:00, 20482.38it/s]\n",
      "100%|██████████| 12500/12500 [00:01<00:00, 11169.05it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 21138.34it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 20492.99it/s]\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "def read_imdb(folder='train', data_root='data/aclImdb'):\n",
    "    \"\"\"Load one IMDb split as a shuffled list of [review_text, label] pairs.\n",
    "\n",
    "    Reads every file under data_root/folder/pos (label 1) and data_root/folder/neg\n",
    "    (label 0), decoding as UTF-8, removing newlines and lower-casing the text.\n",
    "    \"\"\"\n",
    "    data = []\n",
    "    for label in ['pos', 'neg']:\n",
    "        folder_name = os.path.join(data_root, folder, label)\n",
    "        target = 1 if label == 'pos' else 0\n",
    "        for name in tqdm(os.listdir(folder_name)):\n",
    "            with open(os.path.join(folder_name, name), 'rb') as f:\n",
    "                review = f.read().decode('utf-8').replace('\\n', '').lower()\n",
    "            data.append([review, target])\n",
    "    random.shuffle(data)\n",
    "    return data\n",
    "train_data, test_data = read_imdb('train'), read_imdb('test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_tokenized_imdb(data):\n",
    "    \"\"\"Split each review in `data` (a list of [review_text, label]) into lower-cased tokens.\n",
    "\n",
    "    str.split() with no argument splits on any run of whitespace, so double spaces or\n",
    "    tabs do not produce empty-string tokens the way split(' ') does.\n",
    "    \"\"\"\n",
    "    def tokenizer(text):\n",
    "        return [tok.lower() for tok in text.split()]\n",
    "    return [tokenizer(review) for review, _ in data]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('# words in vocab:', 46152)"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_vocab_imdb(data, min_freq=5):\n",
    "    \"\"\"Build a torchtext Vocab from the tokenized reviews in `data`.\n",
    "\n",
    "    min_freq is forwarded to Vocab.Vocab as its frequency cutoff.\n",
    "    \"\"\"\n",
    "    tokenized_data = get_tokenized_imdb(data)\n",
    "    counter = collections.Counter(tk for st in tokenized_data for tk in st)\n",
    "    return Vocab.Vocab(counter, min_freq=min_freq)\n",
    "vocab = get_vocab_imdb(train_data)\n",
    "'# words in vocab:', len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_imdb(data, vocab, max_l=500):\n",
    "    \"\"\"Turn [review, label] pairs into (features, labels) tensors.\n",
    "\n",
    "    Each review is tokenized, mapped to ids via vocab.stoi, then truncated or\n",
    "    right-padded with id 0 to exactly max_l ids, so features has shape\n",
    "    (len(data), max_l). labels holds the 0/1 scores.\n",
    "    \"\"\"\n",
    "    # NOTE(review): in torchtext's Vocab id 0 may be '<unk>' rather than '<pad>', so\n",
    "    # padding reuses that id -- confirm this is intended.\n",
    "    def pad(x):\n",
    "        # Slicing truncates; a negative repeat count yields [], so long inputs get no padding.\n",
    "        return x[:max_l] + [0] * (max_l - len(x))\n",
    "    tokenized_data = get_tokenized_imdb(data)\n",
    "    features = torch.tensor([pad([vocab.stoi[word] for word in words]) for words in tokenized_data])\n",
    "    labels = torch.tensor([score for _, score in data])\n",
    "    return features, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "train_set = Data.TensorDataset(*preprocess_imdb(train_data, vocab))\n",
    "test_set = Data.TensorDataset(*preprocess_imdb(test_data, vocab))\n",
    "train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)\n",
    "test_iter = Data.DataLoader(test_set, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X torch.Size([64, 500]) y torch.Size([64])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "('#batches:', 391)"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "for X, y in train_iter:\n",
    "    print('X', X.shape, 'y', y.shape)\n",
    "    break\n",
    "'#batches:', len(train_iter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BiRNN(nn.Module):\n",
    "    \"\"\"Bidirectional-LSTM sentiment classifier over token ids.\n",
    "\n",
    "    Embeds the tokens, runs a num_layers-deep bidirectional LSTM, concatenates the\n",
    "    top layer's outputs at the first and last time steps and maps them to 2 scores.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab, embed_size, num_hiddens, num_layers):\n",
    "        super(BiRNN, self).__init__()\n",
    "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        self.encoder = nn.LSTM(input_size=embed_size,\n",
    "                               hidden_size=num_hiddens,\n",
    "                               num_layers=num_layers,\n",
    "                               bidirectional=True)\n",
    "        # 2 time steps x 2 directions x num_hiddens features feed the classifier.\n",
    "        self.decoder = nn.Linear(4 * num_hiddens, 2)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: (batch, seq_len). nn.LSTM here expects seq_len first, so transpose\n",
    "        # before embedding -> (seq_len, batch, embed_size).\n",
    "        embeddings = self.embedding(inputs.permute(1, 0))\n",
    "        # outputs: (seq_len, batch, 2 * num_hiddens), top layer, both directions.\n",
    "        outputs, _ = self.encoder(embeddings)\n",
    "        # (batch, 4 * num_hiddens): first and last time steps side by side.\n",
    "        encoding = torch.cat((outputs[0], outputs[-1]), -1)\n",
    "        return self.decoder(encoding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# embed_size = 100 must match the 100-d GloVe vectors copied into net.embedding\n",
    "embed_size, num_hiddens, num_layers = 100, 100, 2\n",
    "net = BiRNN(vocab, embed_size, num_hiddens, num_layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████▉| 399999/400000 [00:43<00:00, 9198.13it/s] \n"
     ]
    }
   ],
   "source": [
    "# 100-d GloVe vectors from the '6B' release, cached under DATA_ROOT/glove\n",
    "glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=os.path.join(DATA_ROOT, 'glove'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 21202 oov words.\n"
     ]
    }
   ],
   "source": [
    "def load_pretrained_embedding(words, pretrained_vocab):\n",
    "    \"\"\"Build a (len(words), dim) matrix of pretrained vectors for `words`.\n",
    "\n",
    "    Row i holds the vector of words[i] looked up in pretrained_vocab. Words whose\n",
    "    lookup raises KeyError (out-of-vocabulary) keep an all-zero row, and their\n",
    "    count is printed when nonzero.\n",
    "    \"\"\"\n",
    "    dim = pretrained_vocab.vectors[0].shape[0]\n",
    "    embed = torch.zeros(len(words), dim)\n",
    "    missing = 0\n",
    "    for row, word in enumerate(words):\n",
    "        try:\n",
    "            vector = pretrained_vocab.vectors[pretrained_vocab.stoi[word]]\n",
    "        except KeyError:\n",
    "            missing += 1\n",
    "            continue\n",
    "        embed[row, :] = vector\n",
    "    if missing:\n",
    "        print('There are %d oov words.' % missing)\n",
    "    return embed\n",
    "\n",
    "net.embedding.weight.data.copy_(load_pretrained_embedding(vocab.itos, glove_vocab))\n",
    "# The embeddings are pretrained, so freeze them\n",
    "net.embedding.weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use 2 GPUs!\n",
      "training on cuda\n",
      "epoch 1, loss 0.6009, train acc 0.675, test acc 0.794, time 103.1 sec\n",
      "epoch 2, loss 0.2098, train acc 0.811, test acc 0.838, time 96.3 sec\n",
      "epoch 3, loss 0.1183, train acc 0.844, test acc 0.851, time 99.5 sec\n",
      "epoch 4, loss 0.0772, train acc 0.870, test acc 0.858, time 98.7 sec\n",
      "epoch 5, loss 0.0542, train acc 0.890, test acc 0.853, time 99.4 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.01, 5\n",
    "# Exclude the frozen embedding weights (requires_grad=False) from the optimizer\n",
    "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_sentiment(net, vocab, sentence):\n",
    "    \"\"\"Classify a tokenized sentence as 'positive' or 'negative'.\n",
    "\n",
    "    net: classifier mapping (1, num_words) token ids to (1, num_classes) scores;\n",
    "         class index 1 is read as 'positive'\n",
    "    vocab: object whose `stoi` maps a word to its token id\n",
    "    sentence: list of words\n",
    "\n",
    "    The network runs in eval mode (so any dropout is disabled) and without\n",
    "    gradient tracking. Its previous train/eval mode is restored afterwards.\n",
    "    \"\"\"\n",
    "    device = next(net.parameters()).device\n",
    "    token_ids = torch.tensor([vocab.stoi[word] for word in sentence], device=device)\n",
    "    was_training = net.training\n",
    "    net.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            label = torch.argmax(net(token_ids.view((1, -1))), dim=1)\n",
    "    finally:\n",
    "        net.train(was_training)\n",
    "    return 'positive' if label.item() == 1 else 'negative'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'positive'"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a clearly positive sentence\n",
    "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'negative'"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a clearly negative sentence\n",
    "predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.8 文本情感分类：使用卷积神经网络（textCNN）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "from torch import nn\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.nn.functional as F\n",
    "import sys\n",
    "import d2lzh as d2l\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "DATA_ROOT = 'data'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def corr1d(X, K):\n",
    "    \"\"\"1-D cross-correlation of X with kernel K (no padding, stride 1).\n",
    "\n",
    "    Returns a tensor of length len(X) - len(K) + 1 in the default float dtype.\n",
    "    \"\"\"\n",
    "    width = K.shape[0]\n",
    "    result = torch.zeros(X.shape[0] - width + 1)\n",
    "    for start in range(len(result)):\n",
    "        result[start] = (X[start: start + width] * K).sum()\n",
    "    return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.,  5.,  8., 11., 14., 17.])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# A width-2 kernel slid over a length-7 input gives 7 - 2 + 1 = 6 outputs\n",
    "X, K = torch.tensor([0, 1, 2, 3, 4, 5, 6]), torch.tensor([1, 2])\n",
    "corr1d(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 2.,  8., 14., 20., 26., 32.])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def corr1d_multi_in(X, K):\n",
    "    \"\"\"Multi-input-channel 1-D cross-correlation.\n",
    "\n",
    "    Cross-correlates each channel of X (indexed along dim 0) with the matching\n",
    "    channel of K, then sums the per-channel results.\n",
    "    \"\"\"\n",
    "    per_channel = [corr1d(x_c, k_c) for x_c, k_c in zip(X, K)]\n",
    "    return torch.stack(per_channel).sum(dim=0)\n",
    "\n",
    "X = torch.tensor([[0, 1, 2, 3, 4, 5, 6],\n",
    "                  [1, 2, 3, 4, 5, 6, 7],\n",
    "                  [2, 3, 4, 5, 6, 7, 8]])\n",
    "K = torch.tensor([[1, 2],\n",
    "                  [3, 4],\n",
    "                  [-1, -3]])\n",
    "corr1d_multi_in(X, K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GlobalMaxPool1d(nn.Module):\n",
    "    \"\"\"Max-over-time pooling: keep the largest value of each channel.\n",
    "\n",
    "    (batch_size, channel, seq_len) -> (batch_size, channel, 1)\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "    def forward(self, x):\n",
    "        # A pooling window as wide as the whole sequence leaves one value per channel\n",
    "        seq_len = x.shape[2]\n",
    "        return F.max_pool1d(x, kernel_size=seq_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 12500/12500 [00:00<00:00, 22277.48it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 19936.30it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 23050.33it/s]\n",
      "100%|██████████| 12500/12500 [00:00<00:00, 21506.03it/s]\n"
     ]
    }
   ],
   "source": [
    "import torch.utils.data as Data\n",
    "batch_size = 64\n",
    "# Load IMDb reviews, build the vocabulary from the training split only, and\n",
    "# wrap both splits as mini-batch iterators (training batches are shuffled)\n",
    "train_data = d2l.read_imdb('train', data_root=os.path.join(DATA_ROOT, 'aclImdb'))\n",
    "test_data = d2l.read_imdb('test', data_root=os.path.join(DATA_ROOT, 'aclImdb'))\n",
    "vocab = d2l.get_vocab_imdb(train_data)\n",
    "train_set = Data.TensorDataset(*d2l.preprocess_imdb(train_data, vocab))\n",
    "test_set = Data.TensorDataset(*d2l.preprocess_imdb(test_data, vocab))\n",
    "train_iter = Data.DataLoader(train_set, batch_size, shuffle=True)\n",
    "test_iter = Data.DataLoader(test_set, batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TextCNN(nn.Module):\n",
    "    \"\"\"textCNN sentiment classifier.\n",
    "\n",
    "    Each token gets two embeddings (one trainable, one meant to stay frozen),\n",
    "    concatenated along the feature axis. Several 1-D convolutions with different\n",
    "    kernel widths run over the sequence; each is followed by ReLU and global\n",
    "    max pooling over time. The pooled features are concatenated, passed through\n",
    "    dropout, and mapped to two class scores.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab, embed_size, kernel_sizes, num_channels):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        # Embedding table that is not meant to be trained\n",
    "        self.constant_embedding = nn.Embedding(len(vocab), embed_size)\n",
    "        self.dropout = nn.Dropout(0.5)\n",
    "        self.decoder = nn.Linear(sum(num_channels), 2)\n",
    "        # Max-over-time pooling has no weights, so one instance serves every branch\n",
    "        self.pool = GlobalMaxPool1d()\n",
    "        # One Conv1d per (out_channels, kernel_size) pair; the input channels are\n",
    "        # the two concatenated embeddings\n",
    "        self.convs = nn.ModuleList([\n",
    "            nn.Conv1d(in_channels=2 * embed_size, out_channels=c, kernel_size=k)\n",
    "            for c, k in zip(num_channels, kernel_sizes)\n",
    "        ])\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        \"\"\"inputs: (batch_size, num_words) token ids -> (batch_size, 2) scores.\"\"\"\n",
    "        # (batch_size, num_words, 2 * embed_size)\n",
    "        stacked = torch.cat((self.embedding(inputs), self.constant_embedding(inputs)), dim=2)\n",
    "        # Conv1d wants channels before length: (batch_size, 2 * embed_size, num_words)\n",
    "        channels_first = stacked.permute(0, 2, 1)\n",
    "        # Each branch pools to (batch_size, c, 1); squeeze the trailing axis and\n",
    "        # concatenate along channels -> (batch_size, sum(num_channels))\n",
    "        pooled = [self.pool(F.relu(conv(channels_first))).squeeze(-1) for conv in self.convs]\n",
    "        encoding = torch.cat(pooled, dim=1)\n",
    "        # Dropout, then the fully connected output layer\n",
    "        return self.decoder(self.dropout(encoding))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "embed_size = 100\n",
    "# Three convolution branches: kernel widths 3, 4, 5 with 100 output channels each\n",
    "kernel_size = [3, 4, 5]\n",
    "nums_channels = [100, 100, 100]\n",
    "net = TextCNN(vocab, embed_size, kernel_size, nums_channels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "There are 21202 oov words.\n",
      "There are 21202 oov words.\n"
     ]
    }
   ],
   "source": [
    "glove_vocab = Vocab.GloVe(name='6B', dim=100, cache=os.path.join(DATA_ROOT, 'glove'))\n",
    "# Look the pretrained vectors up once and copy them into both embedding tables\n",
    "glove_weights = d2l.load_pretrained_embedding(vocab.itos, glove_vocab)\n",
    "net.embedding.weight.data.copy_(glove_weights)\n",
    "net.constant_embedding.weight.data.copy_(glove_weights)\n",
    "# Freeze the constant embedding; net.embedding stays trainable\n",
    "net.constant_embedding.weight.requires_grad = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Let's use 2 GPUs!\n",
      "training on cuda\n",
      "epoch 1, loss 0.4848, train acc 0.755, test acc 0.809, time 26.7 sec\n",
      "epoch 2, loss 0.1607, train acc 0.861, test acc 0.870, time 21.1 sec\n",
      "epoch 3, loss 0.0685, train acc 0.919, test acc 0.876, time 21.3 sec\n",
      "epoch 4, loss 0.0294, train acc 0.958, test acc 0.864, time 21.0 sec\n",
      "epoch 5, loss 0.0127, train acc 0.978, test acc 0.863, time 21.1 sec\n"
     ]
    }
   ],
   "source": [
    "lr, num_epochs = 0.001, 5\n",
    "# Only trainable parameters go to the optimizer; the frozen constant_embedding is excluded\n",
    "optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=lr)\n",
    "loss = nn.CrossEntropyLoss()\n",
    "d2l.train(train_iter, test_iter, net, loss, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'positive'"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a clearly positive sentence\n",
    "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'great'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'negative'"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on a clearly negative sentence\n",
    "d2l.predict_sentiment(net, vocab, ['this', 'movie', 'is', 'so', 'bad'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9.12 机器翻译"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import os\n",
    "import io\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "import torchtext.vocab as Vocab\n",
    "import torch.utils.data as Data\n",
    "import sys\n",
    "import d2lzh as d2l\n",
    "PAD, BOS, EOS = '<pad>', '<bos>', '<eos>'\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helpers that turn token sequences into a vocabulary and an index tensor\n",
    "def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):\n",
    "    \"\"\"Collect one sequence's tokens and store its EOS-terminated, PAD-padded form.\n",
    "\n",
    "    seq_tokens: list of tokens (left unmodified)\n",
    "    all_tokens: list extended in place with seq_tokens (used to build the vocabulary)\n",
    "    all_seqs: list to which the padded sequence is appended\n",
    "    max_seq_len: target length, including the EOS token\n",
    "    \"\"\"\n",
    "    all_tokens.extend(seq_tokens)\n",
    "    # Build a new list; `+=` would silently mutate the caller's list\n",
    "    padded = seq_tokens + [EOS] + [PAD] * (max_seq_len - len(seq_tokens) - 1)\n",
    "    all_seqs.append(padded)\n",
    "\n",
    "def build_data(all_tokens, all_seqs):\n",
    "    \"\"\"Build a vocabulary from all_tokens and convert all_seqs to an index tensor.\n",
    "\n",
    "    PAD, BOS and EOS are registered as special tokens. Returns (vocab, indices),\n",
    "    where indices has one row per sequence.\n",
    "    \"\"\"\n",
    "    vocab = Vocab.Vocab(collections.Counter(all_tokens), specials=[PAD, BOS, EOS])\n",
    "    indices = [[vocab.stoi[w] for w in seq] for seq in all_seqs]\n",
    "    return vocab, torch.tensor(indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data(max_seq_len, path='data/fr-en-small.txt'):\n",
    "    \"\"\"Read tab-separated (input, output) sentence pairs from `path`.\n",
    "\n",
    "    Pairs where either side would exceed max_seq_len tokens once EOS is appended\n",
    "    are skipped, as are blank lines. Returns (in_vocab, out_vocab, dataset), where\n",
    "    dataset is a TensorDataset of (input indices, output indices).\n",
    "    \"\"\"\n",
    "    # in/out are short for input/output\n",
    "    in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []\n",
    "    # French text may contain non-ASCII characters; decode as UTF-8 explicitly\n",
    "    # instead of relying on the platform's default encoding\n",
    "    with io.open(path, encoding='utf-8') as f:\n",
    "        lines = f.readlines()\n",
    "    for line in lines:\n",
    "        line = line.rstrip()\n",
    "        if not line:\n",
    "            continue\n",
    "        in_seq, out_seq = line.split('\\t')\n",
    "        in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')\n",
    "        if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:\n",
    "            # Longer than max_seq_len once EOS is added: skip this sample\n",
    "            continue\n",
    "        process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)\n",
    "        process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)\n",
    "    in_vocab, in_data = build_data(in_tokens, in_seqs)\n",
    "    out_vocab, out_data = build_data(out_tokens, out_seqs)\n",
    "    return in_vocab, out_vocab, Data.TensorDataset(in_data, out_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([ 5,  4, 45,  3,  2,  0,  0]), tensor([ 8,  4, 27,  3,  2,  0,  0]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "max_seq_len = 7\n",
    "in_vocab, out_vocab, dataset = read_data(max_seq_len)\n",
    "# Each example is (input indices, output indices), EOS-terminated and PAD-padded\n",
    "dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Encoder(nn.Module):\n",
    "    \"\"\"GRU encoder for the seq2seq model.\n",
    "\n",
    "    Maps (batch_size, num_steps) token ids to the GRU outputs of shape\n",
    "    (num_steps, batch_size, num_hiddens) and the final hidden state of shape\n",
    "    (num_layers, batch_size, num_hiddens).\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, drop_prob=0, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob)\n",
    "\n",
    "    def forward(self, inputs, state):\n",
    "        # Indices may arrive as floats (e.g. torch.zeros), so cast before the lookup\n",
    "        embedded = self.embedding(inputs.long())\n",
    "        # Swap batch and time axes: (num_steps, batch_size, embed_size)\n",
    "        time_major = embedded.permute(1, 0, 2)\n",
    "        return self.rnn(time_major, state)\n",
    "\n",
    "    def begin_state(self):\n",
    "        # With a None initial state, PyTorch starts from all-zero hidden states\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check on a dummy batch of 4 sequences, each 7 steps long\n",
    "encoder = Encoder(vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2)\n",
    "output, state = encoder(torch.zeros((4, 7)), state=encoder.begin_state())\n",
    "# A GRU's state is just h, whereas an LSTM's state is a tuple (h, c)\n",
    "output.shape, state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_model(input_size, attention_size):\n",
    "    \"\"\"Attention scoring network: Linear -> tanh -> Linear to a single score.\n",
    "\n",
    "    Both linear layers are bias-free. Maps (..., input_size) to (..., 1).\n",
    "    \"\"\"\n",
    "    layers = [\n",
    "        nn.Linear(input_size, attention_size, bias=False),\n",
    "        nn.Tanh(),\n",
    "        nn.Linear(attention_size, 1, bias=False),\n",
    "    ]\n",
    "    return nn.Sequential(*layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def attention_forward(model, enc_states, dec_state):\n",
    "    \"\"\"Compute the attention context vector.\n",
    "\n",
    "    enc_states: (num_steps, batch_size, num_hiddens)\n",
    "    dec_state: (batch_size, num_hiddens)\n",
    "    Returns the attention-weighted sum of enc_states over time steps,\n",
    "    of shape (batch_size, num_hiddens).\n",
    "    \"\"\"\n",
    "    # Repeat the decoder state for every encoder time step and pair it with each\n",
    "    # encoder state along the feature axis\n",
    "    repeated = dec_state.unsqueeze(dim=0).expand_as(enc_states)\n",
    "    scores = model(torch.cat((enc_states, repeated), dim=2))  # (num_steps, batch_size, 1)\n",
    "    # Normalize the scores over time steps\n",
    "    weights = F.softmax(scores, dim=0)\n",
    "    return (weights * enc_states).sum(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([4, 8])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check with zero tensors. The scorer's input is an encoder state\n",
    "# concatenated with a decoder state, hence 2*num_hiddens\n",
    "seq_len, batch_size, num_hiddens = 10, 4, 8\n",
    "model = attention_model(2*num_hiddens, 10)\n",
    "enc_states = torch.zeros((seq_len, batch_size, num_hiddens))\n",
    "dec_state = torch.zeros((batch_size, num_hiddens))\n",
    "attention_forward(model, enc_states, dec_state).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Decoder(nn.Module):\n",
    "    \"\"\"GRU decoder with attention over the encoder outputs.\n",
    "\n",
    "    At each step the previous token is embedded and concatenated with an attention\n",
    "    context vector over the encoder states. The result is fed through a GRU whose\n",
    "    top-layer output is projected to vocabulary scores.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, attention_size, drop_prob=0):\n",
    "        super().__init__()\n",
    "        self.embedding = nn.Embedding(vocab_size, embed_size)\n",
    "        # The scorer sees an encoder state concatenated with a decoder state\n",
    "        self.attention = attention_model(2 * num_hiddens, attention_size)\n",
    "        # The GRU consumes [embedded input; context], hence num_hiddens + embed_size\n",
    "        self.rnn = nn.GRU(num_hiddens + embed_size, num_hiddens, num_layers, dropout=drop_prob)\n",
    "        self.out = nn.Linear(num_hiddens, vocab_size)\n",
    "\n",
    "    def forward(self, cur_input, state, enc_states):\n",
    "        \"\"\"\n",
    "        cur_input: (batch,) token ids for the current step\n",
    "        state: (num_layers, batch, num_hiddens)\n",
    "        enc_states: (num_steps, batch, num_hiddens) encoder outputs\n",
    "        Returns (scores of shape (batch, vocab_size), new state).\n",
    "        \"\"\"\n",
    "        # Attend using the top layer of the current decoder state\n",
    "        context = attention_forward(self.attention, enc_states, state[-1])\n",
    "        # (batch, embed_size + num_hiddens)\n",
    "        step_input = torch.cat((self.embedding(cur_input), context), dim=1)\n",
    "        # Add a length-1 time axis for the GRU\n",
    "        rnn_out, new_state = self.rnn(step_input.unsqueeze(0), state)\n",
    "        # Project, then drop the time axis -> (batch, vocab_size)\n",
    "        scores = self.out(rnn_out).squeeze(dim=0)\n",
    "        return scores, new_state\n",
    "\n",
    "    def begin_state(self, enc_state):\n",
    "        # Start decoding from the encoder's final hidden state\n",
    "        return enc_state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def batch_loss(encoder, decoder, X, Y, loss):\n",
    "    \"\"\"Average per-token loss of one mini-batch under teacher forcing.\n",
    "\n",
    "    X, Y: (batch_size, seq_len) index tensors of source and target sequences.\n",
    "    loss: criterion returning one value per example (e.g. reduction='none').\n",
    "    Positions after a sequence's EOS are masked out; the EOS step itself counts.\n",
    "    Returns a 1-element tensor: the summed loss divided by the number of unmasked\n",
    "    tokens. Helper tensors are created on Y's device, so the function also works\n",
    "    when the model and data live on a GPU.\n",
    "    \"\"\"\n",
    "    batch_size = X.shape[0]\n",
    "    device = Y.device\n",
    "    enc_state = encoder.begin_state()\n",
    "    enc_outputs, enc_state = encoder(X, enc_state)\n",
    "    # Initialize the decoder's hidden state from the encoder\n",
    "    dec_state = decoder.begin_state(enc_state)\n",
    "    # Every sequence starts decoding from BOS\n",
    "    dec_input = torch.tensor([out_vocab.stoi[BOS]] * batch_size, device=device)\n",
    "    # mask[i] == 1 while sequence i has not yet passed its EOS\n",
    "    mask, num_not_pad_tokens = torch.ones(batch_size, device=device), 0\n",
    "    l = torch.tensor([0.0], device=device)\n",
    "    for y in Y.permute(1, 0):  # iterate over time steps; y: (batch_size,)\n",
    "        dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)\n",
    "        l = l + (mask * loss(dec_output, y)).sum()\n",
    "        # Teacher forcing: feed the ground-truth token as the next input\n",
    "        dec_input = y\n",
    "        num_not_pad_tokens += mask.sum().item()\n",
    "        # Everything after EOS is PAD, so once EOS is seen the mask stays 0\n",
    "        mask = mask * (y != out_vocab.stoi[EOS]).float()\n",
    "    return l / num_not_pad_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(encoder, decoder, dataset, lr, batch_size, num_epochs):\n",
    "    \"\"\"Train encoder and decoder jointly with Adam on shuffled mini-batches.\n",
    "\n",
    "    Prints the mean batch loss every 10 epochs.\n",
    "    \"\"\"\n",
    "    optimizers = [torch.optim.Adam(encoder.parameters(), lr=lr),\n",
    "                  torch.optim.Adam(decoder.parameters(), lr=lr)]\n",
    "    # Per-example losses; batch_loss applies the PAD mask and averages\n",
    "    loss = nn.CrossEntropyLoss(reduction='none')\n",
    "    data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)\n",
    "    for epoch in range(1, num_epochs + 1):\n",
    "        total = 0.0\n",
    "        for X, Y in data_iter:\n",
    "            for opt in optimizers:\n",
    "                opt.zero_grad()\n",
    "            l = batch_loss(encoder, decoder, X, Y, loss)\n",
    "            l.backward()\n",
    "            for opt in optimizers:\n",
    "                opt.step()\n",
    "            total += l.item()\n",
    "        if epoch % 10 == 0:\n",
    "            print('epoch %d, loss %.3f' % (epoch, total / len(data_iter)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 10, loss 0.462\n",
      "epoch 20, loss 0.215\n",
      "epoch 30, loss 0.114\n",
      "epoch 40, loss 0.105\n",
      "epoch 50, loss 0.039\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters for the small fr-en training run\n",
    "embed_size, num_hiddens, num_layers = 64, 64, 2\n",
    "attention_size = 10\n",
    "drop_prob = 0.5\n",
    "lr = 0.01\n",
    "batch_size = 2\n",
    "num_epochs = 50\n",
    "encoder = Encoder(len(in_vocab), embed_size, num_hiddens, num_layers, drop_prob)\n",
    "decoder = Decoder(len(out_vocab), embed_size, num_hiddens, num_layers, attention_size, drop_prob)\n",
    "train(encoder, decoder, dataset, lr, batch_size, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate(encoder, decoder, input_seq, max_seq_len):\n",
    "    \"\"\"Greedily translate a space-separated sentence; returns the output tokens.\n",
    "\n",
    "    Decoding stops at EOS or after max_seq_len steps. Both models run in eval\n",
    "    mode (disabling dropout such as the GRUs' inter-layer dropout) and without\n",
    "    gradient tracking. Their previous train/eval modes are restored afterwards.\n",
    "    \"\"\"\n",
    "    in_tokens = input_seq.split(' ')\n",
    "    in_tokens += [EOS] + [PAD] * (max_seq_len - len(in_tokens) - 1)\n",
    "    device = next(encoder.parameters()).device\n",
    "    was_training = (encoder.training, decoder.training)\n",
    "    encoder.eval()\n",
    "    decoder.eval()\n",
    "    output_tokens = []\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            # batch size 1\n",
    "            enc_input = torch.tensor([[in_vocab.stoi[tk] for tk in in_tokens]], device=device)\n",
    "            enc_state = encoder.begin_state()\n",
    "            enc_output, enc_state = encoder(enc_input, enc_state)\n",
    "            dec_input = torch.tensor([out_vocab.stoi[BOS]], device=device)\n",
    "            dec_state = decoder.begin_state(enc_state)\n",
    "            for _ in range(max_seq_len):\n",
    "                dec_output, dec_state = decoder(dec_input, dec_state, enc_output)\n",
    "                pred = dec_output.argmax(dim=1)\n",
    "                pred_token = out_vocab.itos[int(pred.item())]\n",
    "                # The output sequence is complete as soon as EOS is predicted\n",
    "                if pred_token == EOS:\n",
    "                    break\n",
    "                output_tokens.append(pred_token)\n",
    "                dec_input = pred\n",
    "    finally:\n",
    "        encoder.train(was_training[0])\n",
    "        decoder.train(was_training[1])\n",
    "    return output_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['they', 'are', 'watching', '.']"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Translate a French sentence into English with greedy decoding\n",
    "input_seq = 'ils regardent .'\n",
    "translate(encoder, decoder, input_seq, max_seq_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bleu(pred_tokens, label_tokens, k):\n",
    "    \"\"\"BLEU score of a predicted token list against one reference, up to k-grams.\n",
    "\n",
    "    score = exp(min(0, 1 - len_label / len_pred)) * prod_{n=1..k} p_n ** (1 / 2**n),\n",
    "    where p_n is the clipped n-gram precision. n-grams are compared as token\n",
    "    tuples, so different token splits cannot collide (joining without a\n",
    "    separator would make ['a', 'bc'] equal ['ab', 'c']). A prediction that is\n",
    "    empty, or has no n-grams of some order n <= k, scores 0.0 instead of raising\n",
    "    ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    len_pred, len_label = len(pred_tokens), len(label_tokens)\n",
    "    if len_pred == 0:\n",
    "        return 0.0\n",
    "    score = math.exp(min(0, 1 - len_label / len_pred))\n",
    "    for n in range(1, k + 1):\n",
    "        num_pred_ngrams = len_pred - n + 1\n",
    "        if num_pred_ngrams <= 0:\n",
    "            return 0.0\n",
    "        label_subs = collections.Counter(\n",
    "            tuple(label_tokens[i: i + n]) for i in range(len_label - n + 1))\n",
    "        num_matches = 0\n",
    "        for i in range(num_pred_ngrams):\n",
    "            ngram = tuple(pred_tokens[i: i + n])\n",
    "            # Clip: each reference n-gram can be matched at most as often as it occurs\n",
    "            if label_subs[ngram] > 0:\n",
    "                num_matches += 1\n",
    "                label_subs[ngram] -= 1\n",
    "        score *= math.pow(num_matches / num_pred_ngrams, math.pow(0.5, n))\n",
    "    return score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score(input_seq, label_seq, k):\n",
    "    \"\"\"Translate input_seq and print its BLEU score against label_seq.\"\"\"\n",
    "    pred_tokens = translate(encoder, decoder, input_seq, max_seq_len)\n",
    "    bleu_value = bleu(pred_tokens, label_seq.split(' '), k)\n",
    "    print('bleu %.3f, predict: %s' % (bleu_value, ' '.join(pred_tokens)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bleu 1.000, predict: they are watching .\n"
     ]
    }
   ],
   "source": [
    "# BLEU with up to 2-grams against the reference translation\n",
    "score('ils regardent .', 'they are watching .', k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bleu 0.658, predict: they are arguing .\n"
     ]
    }
   ],
   "source": [
    "# BLEU with up to 2-grams against the reference translation\n",
    "score('ils sont canadienne .', 'they are canadian .', k=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:PyTorch] *",
   "language": "python",
   "name": "conda-env-PyTorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
